package com.example.demo.config.filter;

import com.example.demo.utils.MessageCodeEnum;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PropertiesLoaderUtils;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Properties;

/**
 * Servlet filter that rejects requests whose host is not on a configured allow-list.
 *
 * <p>The allow-list is read once in {@link #init(FilterConfig)} from the comma-separated
 * {@code trustHosts} property of the classpath resource
 * {@code config/security/trusthosts.application}. The check is applied only when the
 * {@code config.trustHost} property is {@code true} and the allow-list is non-empty; otherwise
 * every request is passed through unchanged.
 *
 * <p>Both the {@code Host} header and {@link ServletRequest#getServerName()} must match an
 * allow-list entry, otherwise the response status is set to 403 and the chain is not invoked.
 * Matching is exact and case-insensitive; a port on the request host is ignored unless the
 * entry itself contains a port. An entry starting with {@code .} (e.g. {@code .example.com})
 * matches {@code example.com} and any of its subdomains.
 */
@WebFilter
public class TrustHostsFilter implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TrustHostsFilter.class);

    /** Classpath location of the allow-list. Classpath paths always use '/', never File.separator. */
    private static final String TRUST_HOSTS_RESOURCE = "config/security/trusthosts.application";

    /** Trimmed, non-empty allow-list entries; {@code null} until loaded and after {@link #destroy()}. */
    private String[] trustHosts;

    /** Master switch for the host check, injected from {@code config.trustHost}. */
    @Value("${config.trustHost}")
    private boolean trustHost;

    /**
     * Passes the request down the chain when the host check is disabled or both request hosts
     * are trusted; otherwise sets a 403 status and stops processing.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        String[] configured = trustHosts;
        if (!trustHost || configured == null || configured.length == 0) {
            // Check disabled or no allow-list loaded: let everything through.
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        HttpServletResponse httpServletResponse = (HttpServletResponse) response;
        String host1 = httpServletRequest.getHeader("Host");
        String host2 = httpServletRequest.getServerName();
        if (this.isTrustHost(host1) && this.isTrustHost(host2)) {
            chain.doFilter(request, response);
        } else {
            LOGGER.info("当前host[{},{}]不是信任的host，禁止访问", host1, host2);
            httpServletResponse.setStatus(MessageCodeEnum.CODE_403.getCode());
        }
    }

    /**
     * Loads the allow-list from the classpath.
     *
     * <p>Failures are logged and swallowed, leaving {@code trustHosts} null, which makes
     * {@link #doFilter} pass every request through. NOTE(review): this fails open; confirm that
     * is intended rather than refusing to start.
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        try {
            Resource resource = new ClassPathResource(TRUST_HOSTS_RESOURCE);
            Properties props = PropertiesLoaderUtils.loadProperties(resource);
            String hosts = props.getProperty("trustHosts");
            if (hosts != null) {
                trustHosts = parseTrustHosts(hosts);
            }
        } catch (Exception e) {
            LOGGER.error("加载TrustHosts属性文件报错", e);
        }
    }

    /** Drops the loaded allow-list. */
    @Override
    public void destroy() {
        trustHosts = null;
    }

    /**
     * Returns whether {@code host} matches an allow-list entry.
     *
     * @param host a host as sent by the client, optionally with {@code :port}; may be null
     * @return {@code true} if the host (with or without its port) equals an entry ignoring case,
     *     or falls under a leading-dot entry; {@code false} for null/blank hosts or when no
     *     allow-list is loaded
     */
    public boolean isTrustHost(String host) {
        String[] configured = trustHosts;
        if (host == null || configured == null) {
            return false;
        }
        String candidate = host.trim();
        if (candidate.isEmpty()) {
            return false;
        }
        String hostOnly = stripPort(candidate);
        for (String entry : configured) {
            // Compare the full value too, so entries that include a port (host:8080) still work.
            if (matchesEntry(candidate, entry) || matchesEntry(hostOnly, entry)) {
                return true;
            }
        }
        return false;
    }

    /** Splits the comma-separated property into trimmed entries, discarding empty ones. */
    private static String[] parseTrustHosts(String raw) {
        List<String> entries = new ArrayList<>();
        for (String part : raw.split(",")) {
            String trimmed = part.trim();
            // An empty entry must never be kept: it would otherwise match every host.
            if (!trimmed.isEmpty()) {
                entries.add(trimmed);
            }
        }
        return entries.toArray(new String[0]);
    }

    /** Removes a trailing {@code :port} from a host, handling bracketed IPv6 literals. */
    private static String stripPort(String host) {
        if (host.startsWith("[")) {
            int end = host.indexOf(']');
            return end > 0 ? host.substring(0, end + 1) : host;
        }
        int colon = host.lastIndexOf(':');
        // Exactly one colon means host:port; several colons without brackets is a bare IPv6 address.
        if (colon > 0 && host.indexOf(':') == colon) {
            return host.substring(0, colon);
        }
        return host;
    }

    /** Case-insensitive exact match, or apex/subdomain match for entries starting with '.'. */
    private static boolean matchesEntry(String host, String entry) {
        if (entry.startsWith(".")) {
            String apex = entry.substring(1);
            return host.equalsIgnoreCase(apex)
                    || (host.length() > entry.length()
                            && host.regionMatches(true, host.length() - entry.length(), entry, 0, entry.length()));
        }
        return host.equalsIgnoreCase(entry);
    }
}
